import os
import numpy as np
from PIL import Image
import json

from model_service.hiai_model_service import HiaiBaseService
from hiai.nntensor_list import NNTensorList
from hiai.nn_tensor_lib import NNTensor

# Model metadata lives in a JSON file named 'index' next to this module.
# Read it as UTF-8 explicitly so non-ASCII label names decode the same way
# regardless of the host locale's default encoding.
current_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(current_dir, 'index'), 'r', encoding='utf-8') as f:
  index_map = json.load(f)
class_names = index_map['labels_list']  # class id -> label name
image_shape = index_map['image_shape']

# Network input size: image_shape[0] is treated as the height and
# image_shape[1] as the width throughout this module.
net_h = int(image_shape[0])
net_w = int(image_shape[1])

class_num = len(class_names)

# Output strides of the three detection scales. Each scale has three anchor
# (w, h) pairs, divided by that scale's stride so they are in grid-cell units.
stride_list = [8, 16, 32]
anchors_1 = np.array([[10, 13], [16, 30], [33, 23]]) / stride_list[0]
anchors_2 = np.array([[30, 61], [62, 45], [59, 119]]) / stride_list[1]
anchors_3 = np.array([[116, 90], [156, 198], [163, 326]]) / stride_list[2]
anchor_list = [anchors_1, anchors_2, anchors_3]

# Minimum (objectness * best class probability) score for a candidate box.
conf_threshold = 0.3
# Within one class, a box whose IoU with a higher-scoring kept box is >= this
# value is suppressed by NMS.
iou_threshold = 0.4

# RGB colour palette; not used by any function visible in this file.
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 0, 255), (255, 255, 0)]


def preprocess(image, aipp_flag=True):
  """Letterbox a PIL image into the network input size.

  The image is resized (aspect ratio preserved) to fit inside net_w x net_h,
  pasted centred onto a canvas filled with the value 128, and returned
  together with the geometry needed to map detections back to the original.

  :param image: PIL image; converted to RGB first if it is in another mode.
  :param aipp_flag: if True the canvas is uint8 with raw pixel values;
      otherwise it is float32 scaled to [0, 1].
  :return: (canvas, img_w, img_h, new_w, new_h, shift_x_ratio, shift_y_ratio)
      where canvas has shape (net_h, net_w, 3), (img_w, img_h) is the original
      size, (new_w, new_h) the resized size, and the shift ratios are the
      left/top padding as a fraction of the network width/height.
  """
  # The canvas has exactly 3 channels; grayscale ('L'), palette ('P') or RGBA
  # inputs would otherwise fail to broadcast into it.
  if image.mode != 'RGB':
    image = image.convert('RGB')

  img_w, img_h = image.size

  scale = min(float(net_w) / float(img_w), float(net_h) / float(img_h))
  new_w = int(img_w * scale)
  new_h = int(img_h * scale)

  shift_x = (net_w - new_w) // 2
  shift_y = (net_h - new_h) // 2
  # Derive the ratios from the integer offsets actually used for pasting, so
  # decoding undoes exactly the applied shift. (net - new) / 2.0 would be half
  # a pixel larger whenever the padding is odd.
  shift_x_ratio = float(shift_x) / net_w
  shift_y_ratio = float(shift_y) / net_h

  resized = image.resize((new_w, new_h))

  dtype = np.uint8 if aipp_flag else np.float32
  new_image = np.full((net_h, net_w, 3), 128, dtype=dtype)
  new_image[shift_y:shift_y + new_h, shift_x:shift_x + new_w, :] = np.asarray(resized)

  if not aipp_flag:
    new_image /= 255.

  return new_image, img_w, img_h, new_w, new_h, shift_x_ratio, shift_y_ratio


def overlap(x1, x2, x3, x4):
  """Return the signed length shared by intervals [x1, x2] and [x3, x4].

  A result <= 0 means the intervals do not intersect.
  """
  start = x3 if x3 > x1 else x1
  end = x4 if x4 < x2 else x2
  return end - start


def cal_iou(box, truth):
  """Return the intersection-over-union of two axis-aligned boxes.

  Both boxes are indexable as (x_min, y_min, x_max, y_max). The result is the
  int 0 when the boxes do not overlap along either axis, otherwise a float.
  """
  inter_w = overlap(box[0], box[2], truth[0], truth[2])
  inter_h = overlap(box[1], box[3], truth[1], truth[3])
  if inter_w <= 0 or inter_h <= 0:
    return 0

  intersection = inter_w * inter_h
  box_area = (box[2] - box[0]) * (box[3] - box[1])
  truth_area = (truth[2] - truth[0]) * (truth[3] - truth[1])
  union = box_area + truth_area - intersection
  return intersection * 1.0 / union


def apply_nms(all_boxes, thres):
  """Greedy non-maximum suppression, run independently on each bucket.

  :param all_boxes: sequence of per-class lists; each box is indexable as
      [x_min, y_min, x_max, y_max, class_id, score].
  :param thres: a box is discarded when its IoU with a higher-ranked, still
      kept box of the same bucket is >= thres.
  :return: flat list of surviving boxes, grouped by bucket and ordered by
      descending score within each bucket.
  """
  res = []
  # Walk the buckets actually supplied rather than range(class_num), so a
  # shorter list cannot raise IndexError and no module global is needed.
  for cls_bboxes in all_boxes:
    # Reversed stable ascending sort: highest score first, ties in reverse
    # input order (kept to preserve the existing tie-breaking).
    sorted_boxes = sorted(cls_bboxes, key=lambda d: d[5])[::-1]

    suppressed = set()
    for i, truth in enumerate(sorted_boxes):
      if i in suppressed:
        continue
      for j in range(i + 1, len(sorted_boxes)):
        if j not in suppressed and cal_iou(sorted_boxes[j], truth) >= thres:
          suppressed.add(j)

    res.extend(box for k, box in enumerate(sorted_boxes) if k not in suppressed)
  return res


def decode_bbox(conv_output, anchors, img_w, img_h, x_scale, y_scale, shift_x_ratio, shift_y_ratio):
  """Decode one detection head into per-class candidate boxes.

  :param conv_output: array of shape (3 * (5 + class_num), h, w) for one scale.
  :param anchors: (3, 2) array of anchor (w, h) sizes in grid-cell units.
  :param img_w, img_h: original image size in pixels.
  :param x_scale, y_scale: factors undoing the letterbox resize
      (get_result passes net_w / new_w and net_h / new_h).
  :param shift_x_ratio, shift_y_ratio: letterbox padding as a fraction of the
      network input width/height.
  :return: list of class_num lists; list k holds boxes of class k as
      [x_min, y_min, x_max, y_max, class_id, score] in original-image pixels,
      keeping only boxes with score >= conf_threshold.
  """
  def _sigmoid(x):
    return 1 / (1 + np.exp(-x))

  _, h, w = conv_output.shape
  # (C, h, w) -> (h*w cells in row-major order, 3 anchors, 5 + class_num).
  pred = conv_output.transpose((1, 2, 0)).reshape((h * w, 3, 5 + class_num))

  # Column (x) and row (y) index of every cell, repeated for the 3 anchors.
  grid_x = np.tile(range(w), (3, h)).transpose((1, 0))
  grid_y = np.tile(np.repeat(range(h), w), (3, 1)).transpose((1, 0))

  pred[..., 4:] = _sigmoid(pred[..., 4:])
  pred[..., 0] = (_sigmoid(pred[..., 0]) + grid_x) / w
  pred[..., 1] = (_sigmoid(pred[..., 1]) + grid_y) / h
  pred[..., 2] = np.exp(pred[..., 2]) * anchors[:, 0:1].transpose((1, 0)) / w
  pred[..., 3] = np.exp(pred[..., 3]) * anchors[:, 1:2].transpose((1, 0)) / h

  # Undo letterbox padding and resize, then clip every edge into the image.
  # Clipping both bounds of each edge (not only the lower bound of the min
  # edge and the upper bound of the max edge) guarantees x_min <= x_max and
  # y_min <= y_max even for boxes centred in the padding.
  half_w = pred[..., 2] / 2.0
  half_h = pred[..., 3] / 2.0
  bbox = np.zeros((h * w, 3, 4))
  bbox[..., 0] = np.clip((pred[..., 0] - half_w - shift_x_ratio) * x_scale * img_w, 0, img_w)  # x_min
  bbox[..., 1] = np.clip((pred[..., 1] - half_h - shift_y_ratio) * y_scale * img_h, 0, img_h)  # y_min
  bbox[..., 2] = np.clip((pred[..., 0] + half_w - shift_x_ratio) * x_scale * img_w, 0, img_w)  # x_max
  bbox[..., 3] = np.clip((pred[..., 1] + half_h - shift_y_ratio) * y_scale * img_h, 0, img_h)  # y_max

  pred[..., :4] = bbox
  pred = pred.reshape((-1, 5 + class_num))
  # Score = objectness * best class probability.
  pred[:, 4] = pred[:, 4] * pred[:, 5:].max(1)
  pred = pred[pred[:, 4] >= conf_threshold]
  # Column 5 now holds the 0-based index of the best class.
  pred[:, 5] = np.argmax(pred[:, 5:], axis=-1)

  all_boxes = [[] for _ in range(class_num)]
  for row in pred:
    box = [int(v) for v in row[:4]]
    class_id = int(row[5])
    box.append(class_id)
    box.append(row[4])
    # class_id comes from argmax, so it is 0-based and indexes its bucket directly.
    all_boxes[class_id].append(box)

  return all_boxes


def get_result(model_outputs, img_w, img_h, new_w, new_h, shift_x_ratio, shift_y_ratio):
  """Decode the three detection heads and apply per-class NMS.

  model_outputs[2 - i] is reshaped to the grid of stride stride_list[i] and
  decoded with anchor_list[i]: the last output is paired with stride 8 and the
  first with stride 32. NOTE(review): this assumes the model emits its heads
  in that order - confirm against the converted model.

  :return: list of surviving [x_min, y_min, x_max, y_max, class_id, score]
      boxes in original-image pixels.
  """
  per_anchor = class_num + 5
  x_scale = net_w / float(new_w)
  y_scale = net_h / float(new_h)

  merged = [[] for _ in range(class_num)]
  for level in range(3):
    stride = stride_list[level]
    grid = (3 * per_anchor, net_h // stride, net_w // stride)
    head = model_outputs[2 - level].reshape(grid)
    decoded = decode_bbox(head, anchor_list[level], img_w, img_h,
                          x_scale, y_scale, shift_x_ratio, shift_y_ratio)
    merged = [merged[c] + decoded[c] for c in range(class_num)]

  return apply_nms(merged, iou_threshold)


class DemoService(HiaiBaseService):
  """Detection service built on HiaiBaseService using this module's helpers."""

  def _preprocess(self, data):
    """Letterbox every uploaded image and wrap the results in an NNTensorList.

    :param data: mapping of input name -> {file_name: file-like object}, as
        iterated below.
    :return: {'images': NNTensorList} with one NNTensor per image.

    NOTE(review): the letterbox geometry (img_w, img_h, new_w, new_h, shift
    ratios) is stored on self and overwritten for each image, so with more
    than one image _postprocess decodes with the LAST image's geometry.
    """
    # The module reads image_shape[0] as the network height (net_h) and
    # image_shape[1] as the width (net_w); use the same convention here.
    self.input_height = int(image_shape[0])
    self.input_width = int(image_shape[1])
    self.aipp_flag = True
    images = []
    for v in data.values():
      for file_content in v.values():
        # The letterbox canvas has 3 channels; normalise grayscale, RGBA or
        # palette uploads to RGB first.
        input_rgb = Image.open(file_content).convert('RGB')
        (img_preprocess, self.img_w, self.img_h, self.new_w, self.new_h,
         self.shift_x_ratio, self.shift_y_ratio) = preprocess(input_rgb, self.aipp_flag)
        images.append(NNTensor(img_preprocess))
    preprocessed_data = {'images': NNTensorList(images)}
    return preprocessed_data

  def _inference(self, data, image_info=None):
    """Run self.model.proc on every entry of data, keeping the same keys."""
    result = {}
    for k, v in data.items():
      result[k] = self.model.proc(v)
    return result

  def _postprocess(self, data):
    """Convert raw model outputs into detection lists.

    :return: dict with 'detection_classes' (label names), 'detection_boxes'
        ([y_min, x_min, y_max, x_max] in original-image pixels) and
        'detection_scores'; all three are empty lists when nothing is found.
    """
    # Only the last entry of data is decoded (presumably the single 'images'
    # entry produced by _preprocess/_inference). Starting from an empty result
    # makes an empty data dict yield "no detections" instead of an
    # UnboundLocalError.
    res = []
    for v in data.values():
      res = get_result(v, self.img_w, self.img_h, self.new_w, self.new_h, self.shift_x_ratio, self.shift_y_ratio)

    if not res:
      return {
        'detection_classes': [],
        'detection_boxes': [],
        'detection_scores': [],
      }

    new_res = np.array(res)
    # Boxes are (x_min, y_min, x_max, y_max); the response uses
    # (y_min, x_min, y_max, x_max).
    picked_boxes = new_res[:, [1, 0, 3, 2]]
    return {
      'detection_classes': self.convert_labels(new_res[:, 4]),
      'detection_boxes': picked_boxes.tolist(),
      'detection_scores': new_res[:, 5].tolist(),
    }

  def convert_labels(self, label_list):
    """Map numeric class indices to their names in class_names.

    Example, with class_names = ['person', 'face']:
        [1., 1., 0.] -> ['face', 'face', 'person']

    :param label_list: list or numpy array of (possibly float) class indices.
    :return: list of label name strings.
    """
    if isinstance(label_list, np.ndarray):
      label_list = label_list.tolist()
    label_names = [class_names[int(index)] for index in label_list]
    return label_names

  def ping(self):
    return

  def signature(self):
    pass
